#!/usr/bin/env python
from contextlib import contextmanager
import logging
import math
import os
import os.path
import subprocess
import sys

import boto3
import click
import docker
import yaml

import universe

# Logger shared by every function in this file; the `cli` group below
# attaches its stderr handler and sets its level.
logger = logging.getLogger('cluster')

# Maps the number of -v/--verbose flags to a logging level. `cli` falls back
# to logging.DEBUG for any count that is not listed here.
DEBUG_LOGGING_MAP = {
    0: logging.WARNING,
    1: logging.INFO,
    2: logging.DEBUG
}

# Container-side ports. `build_compose` publishes each of these (and the port
# immediately below each); `start` looks up their host mappings to build the
# vnc:// endpoint strings it prints.
VNC_PORT = 5900
REWARDER_PORT = 15900


def get_ports(cli, n):
    """
    Return up to ``n`` host ports that no listed container currently publishes.

    Candidates are scanned in ascending order over the range 5000-9999 and a
    candidate is skipped when it appears as the ``PublicPort`` of any port
    entry reported by ``cli.containers()``.

    :param cli: docker client; only ``cli.containers()`` is called on it.
    :param n: number of ports wanted.
    :returns: ascending list of at most ``n`` port numbers. The list is
        shorter than ``n`` when the candidate range runs out, and empty when
        ``n`` is zero or negative.
    """
    used_ports = set()
    for container in cli.containers():
        for port in container.get('Ports', []):
            # Entries without a 'PublicPort' contribute None, which never
            # matches an integer candidate below.
            used_ports.add(port.get('PublicPort'))

    ports = []
    for candidate in range(5000, 10000):
        # Test the quota *before* taking a port: checking only after an
        # append means n <= 0 is never matched and every free port in the
        # range would be returned.
        if len(ports) >= n:
            break
        if candidate not in used_ports:
            ports.append(candidate)

    return ports


def start_ssh_tunnel(host, local_port, remote_path, key_path):
    """
    Spawn an ``ssh`` process that forwards ``localhost:local_port`` to
    ``remote_path`` through ``ec2-user@host``.

    Host key checking is disabled and known hosts are discarded. The remote
    command prints ``tunnel-ready`` and then sleeps for 3600 seconds, which
    keeps the forward open for that long.

    :param host: address of the machine to SSH into.
    :param local_port: local port to listen on.
    :param remote_path: forward destination as seen from ``host``.
    :param key_path: private key file passed with ``-i``; skipped when falsy.
    :returns: the ``subprocess.Popen`` object, with stdin, stdout and stderr
        all connected to pipes.
    """
    forward_spec = 'localhost:{}:{}'.format(local_port, remote_path)
    identity_args = ['-i', key_path] if key_path else []

    cmd = (
        ['ssh', '-L', forward_spec,
         '-o', 'stricthostkeychecking=no',
         '-o', 'UserKnownHostsFile=/dev/null']
        + identity_args
        + ['ec2-user@{}'.format(host), 'echo tunnel-ready; sleep 3600']
    )

    logger.debug('Starting SSH tunnel: %s', cmd)

    return subprocess.Popen(cmd,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)


@contextmanager
def docker_cli(host, local_port, remote_path='localhost:4000',
               key_path=None):
    """
    Context manager yielding a docker client that talks through an SSH tunnel.

    An ``ssh`` port forward is started with ``start_ssh_tunnel``; once the
    process prints ``tunnel-ready`` a ``docker.Client`` pointed at
    ``localhost:local_port`` is yielded. On exit the ssh process is
    terminated, its stderr is logged at debug level and the process is reaped.

    :param host: address of the machine to SSH into.
    :param local_port: local port the tunnel listens on.
    :param remote_path: forward destination as seen from ``host``.
    :param key_path: optional private key file for the SSH connection.
    :raises RuntimeError: if the ssh process closes stdout before printing
        ``tunnel-ready`` (for example because authentication failed).
    """
    ssh_tunnel = start_ssh_tunnel(host, local_port, remote_path, key_path)
    # The readiness wait lives inside the try block so that the ssh process
    # is cleaned up even if waiting is interrupted or the tunnel never comes up.
    try:
        ready = False
        # The pipe yields bytes, so the EOF sentinel must be b''. With a str
        # sentinel the loop never ends on Python 3 once ssh has exited.
        for line in iter(ssh_tunnel.stdout.readline, b''):
            if line.strip() == b'tunnel-ready':
                ready = True
                break
        if not ready:
            raise RuntimeError(
                'SSH tunnel to {} exited before becoming ready'.format(host))
        logger.debug('SSH tunnel ready [pid: %s]', ssh_tunnel.pid)
        yield docker.Client(base_url='localhost:{}'.format(local_port))
    finally:
        logger.debug('SSH tunnel terminating')
        ssh_tunnel.terminate()
        for line in iter(ssh_tunnel.stderr.readline, b''):
            line = line.strip()
            if not line:
                break
            logger.debug('[SSH tunnel stderr] %s', line)
        # Reap the child so it does not linger as a zombie process.
        ssh_tunnel.wait()

def get_compose_file(runtime):
    """
    Return the path of the compose file for ``runtime``.

    The path is ``gen/<runtime>/docker-compose.yaml`` underneath the directory
    holding this source file, with symlinks resolved.
    """
    this_file = os.path.realpath(__file__)
    base_dir = os.path.dirname(this_file)
    relative_parts = ('gen', runtime, 'docker-compose.yaml')
    return os.path.join(base_dir, *relative_parts)


def build_compose(cli, runtime, n):
    """
    Generate and write the docker-compose file for ``n`` copies of ``runtime``.

    Each service is named ``<runtime>-<index>`` and takes its image, command,
    ``cap_add`` and optional ``ipc`` mode from ``universe.runtime_spec``. Every
    service publishes four container ports (the VNC and rewarder ports plus
    the port immediately below each) on host ports picked by ``get_ports``.

    :param cli: docker client used to find host ports that are still free.
    :param runtime: runtime ID understood by ``universe.runtime_spec``.
    :param n: number of services to define.
    :returns: path of the compose file that was written.
    :raises RuntimeError: if fewer free host ports are available than the
        ``4 * n`` required.
    """
    spec = universe.runtime_spec(runtime)
    expose = [VNC_PORT-1, VNC_PORT, REWARDER_PORT-1, REWARDER_PORT]

    needed = len(expose) * n
    usable_ports = get_ports(cli, needed)
    # get_ports may return a short list; fail with a clear message here
    # instead of an opaque IndexError while pairing ports below.
    if len(usable_ports) < needed:
        raise RuntimeError(
            'Need {} free host ports for {} environments but only {} are '
            'available'.format(needed, n, len(usable_ports)))

    output = {
        'version': '2'
    }

    services = {}
    for i in range(n):
        service = {
            'image': spec.image,
            'command': spec.command,
            'cap_add': spec.host_config.get('cap_add', [])
        }
        if spec.host_config.get('ipc_mode'):
            service['ipc'] = spec.host_config['ipc_mode']

        service['cpu_shares'] = int(math.ceil(spec.default_params.get('cpu', 4)))

        # Service i owns the i-th consecutive group of len(expose) host ports.
        host_ports = usable_ports[i * len(expose):(i + 1) * len(expose)]
        service['ports'] = [
            '{host}:{container}'.format(container=container_port, host=host_port)
            for host_port, container_port in zip(host_ports, expose)
        ]

        # The runtime label is what get_runtime_containers filters on.
        service['labels'] = {
            'universe.runtime': runtime,
            'universe.index': str(i)
        }

        services['{}-{}'.format(runtime, i)] = service

    output['services'] = services
    content = yaml.dump(output)

    filepath = get_compose_file(runtime)

    directory = os.path.dirname(filepath)
    if not os.path.exists(directory):
        logger.info('Creating directory: %s', directory)
        os.makedirs(directory)

    logger.info('Writing compose file to %s', filepath)
    with open(filepath, 'w') as f:
        f.write(content)

    return filepath


def start_compose(cli, filepath):
    """
    Run ``docker-compose up -d --remove-orphans`` for ``filepath`` against the
    docker host at ``cli.base_url``.

    Raises ``subprocess.CalledProcessError`` when docker-compose exits with a
    non-zero status.
    """
    command = ['docker-compose', '-H', cli.base_url, '-f', filepath]
    command += ['up', '-d', '--remove-orphans']
    subprocess.check_call(command)


def stop_compose(cli, filepath):
    """
    Run ``docker-compose down --remove-orphans`` for ``filepath`` against the
    docker host at ``cli.base_url``.

    Raises ``subprocess.CalledProcessError`` when docker-compose exits with a
    non-zero status.
    """
    command = ['docker-compose', '-H', cli.base_url, '-f', filepath]
    command += ['down', '--remove-orphans']
    subprocess.check_call(command)


class Stack(object):
    """
    The pieces of a CloudFormation stack description this script uses.

    Attributes set by the constructor:
      name        -- value of ``StackName``
      docker_ip   -- value of the ``DockerIP`` stack output
      worker_asg  -- value of the ``ASGName`` stack output
    """

    def __init__(self, data):
        """
        Build from one stack description dict.

        ``data`` must contain ``StackName`` and an ``Outputs`` list whose
        entries carry ``OutputKey`` / ``OutputValue``; a missing key raises
        ``KeyError``.
        """
        self.name = data['StackName']

        # Flatten the list of key/value records into a plain lookup table.
        by_key = {
            entry['OutputKey']: entry['OutputValue']
            for entry in data['Outputs']
        }
        self.docker_ip = by_key['DockerIP']
        self.worker_asg = by_key['ASGName']


def get_stack(name):
    """
    Look up the CloudFormation stack called ``name`` and wrap the first
    result in a ``Stack``.

    Raises ``Exception`` when the response lists no stacks.
    """
    cloudformation = boto3.client('cloudformation')
    stacks = cloudformation.describe_stacks(StackName=name)['Stacks']
    if len(stacks) < 1:
        raise Exception('Failed to find CloudFormation stack {}'.format(name))

    first_stack = stacks[0]
    return Stack(first_stack)


def get_worker_instances(stack_name):
    """
    Return the running EC2 instances tagged as belonging to ``stack_name``.

    The result maps each instance's private IP address to a dict with its
    ``id`` and ``public_ip`` (``None`` when the instance description has no
    ``PublicIpAddress``).
    """
    stack_filter = {
        'Name': 'tag:aws:cloudformation:stack-name',
        'Values': [stack_name],
    }
    running_filter = {
        'Name': 'instance-state-name',
        'Values': ['running'],
    }

    ec2 = boto3.client('ec2')
    response = ec2.describe_instances(Filters=[stack_filter, running_filter])

    # Instances arrive grouped by reservation; flatten them into one mapping.
    return {
        instance['PrivateIpAddress']: {
            'id': instance['InstanceId'],
            'public_ip': instance.get('PublicIpAddress'),
        }
        for reservation in response['Reservations']
        for instance in reservation['Instances']
    }


def get_runtime_containers(cli, runtime):
    """
    Describe the containers labelled ``universe.runtime=<runtime>``.

    Returns a dict keyed by each container's ``com.docker.compose.service``
    label. Every value holds the container's ``labels``, its ``ports`` list
    and a ``host`` address derived from a container name of the form
    ``/ip-A-B-C-D/...`` (``None`` when no name starts with ``/ip-``).
    """
    label_filter = 'universe.runtime={}'.format(runtime)
    matching = cli.containers(filters={'label': [label_filter]})

    prefix = '/ip-'
    result = {}
    for info in matching:
        labels = info['Labels']

        host = None
        for container_name in info['Names']:
            if not container_name.startswith(prefix):
                continue
            # First path component is e.g. "ip-10-0-0-1": drop the "ip-" and
            # turn the dashes back into dots. The last matching name wins.
            node = container_name.split('/')[1]
            host = node[len('ip-'):].replace('-', '.')

        entry = {
            'labels': labels,
            'ports': info['Ports'],
            'host': host
        }
        result[labels['com.docker.compose.service']] = entry

    return result


@click.group()
@click.option('--verbose', '-v',
              help='Sets the debug noise level, specify multiple times for more verbosity.',
              type=click.IntRange(0, 3, clamp=True),
              count=True)
def cli(verbose):
    # Root command group: configures the 'cluster' logger before a subcommand
    # runs. Documented with comments on purpose, since click would show a
    # docstring as --help text.
    #
    # `verbose` is the number of -v flags given; counts missing from
    # DEBUG_LOGGING_MAP are treated as DEBUG.
    level = DEBUG_LOGGING_MAP.get(verbose, logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(formatter)

    logger.addHandler(stderr_handler)
    logger.setLevel(level)


@cli.command()
@click.option('--stack', '-s', 'stack_name', required=True,
              help='AWS CloudFormation Stack name.')
@click.option('--port', '-P', default=2374,
              help='Local port to use for remote docker connection')
@click.option('--key-path', '-i',
              help='Path to private key for SSH connection to docker host')
@click.option('--runtime', default='flashgames',
              help='Runtime ID to start. (See universe/runtimes/__init__ for a starting list)')
@click.option('-n', default=1,
              help='Number of environments to start')
def start(stack_name, port, key_path, runtime, n):
    # Start `n` environments of `runtime` on the stack's docker host and print
    # one vnc://<public ip>:<vnc port>+<rewarder port> remote per container.
    # Documented with comments on purpose, since click would show a docstring
    # as --help text.
    #
    # Containers on a host missing from the stack's running instances, or
    # lacking a VNC/rewarder port mapping, are skipped with a warning.
    stack = get_stack(stack_name)
    with docker_cli(stack.docker_ip, port, key_path=key_path) as docker_client:
        filepath = build_compose(docker_client, runtime, n)
        start_compose(docker_client, filepath)

        containers = get_runtime_containers(docker_client, runtime)
        instances = get_worker_instances(stack.name)

        endpoints = []
        for name, container in containers.items():
            # Reset for every container so that a missing mapping can never
            # pick up the previous container's ports (or be unbound).
            vnc_port = None
            rewarder_port = None
            # Named `mapping` so it does not shadow the `port` CLI option.
            for mapping in container['ports']:
                if mapping['PrivatePort'] == VNC_PORT:
                    vnc_port = mapping.get('PublicPort')
                elif mapping['PrivatePort'] == REWARDER_PORT:
                    rewarder_port = mapping.get('PublicPort')

            if container['host'] not in instances:
                logger.warning(
                    'Container %s on unknown host %s', name, container['host'])
                continue

            if vnc_port is None or rewarder_port is None:
                logger.warning(
                    'Container %s does not publish both the VNC and rewarder ports',
                    name)
                continue

            host_ip = instances[container['host']]['public_ip']
            endpoints.append(
                'vnc://{}:{}+{}'.format(host_ip, vnc_port, rewarder_port))

        print('Environments started.')
        print('Remotes:')
        for endpoint in endpoints:
            print('\t{}'.format(endpoint))


@cli.command()
@click.option('--stack', '-s', 'stack_name', required=True,
              help='AWS CloudFormation Stack name.')
@click.option('--port', '-P', default=2374,
              help='Local port to use for remote docker connection')
@click.option('--key-path', '-i',
              help='Path to private key for SSH connection to docker host')
@click.option('--runtime', default='flashgames',
              help='Runtime ID to stop. (See universe/runtimes/__init__ for a starting list)')
def stop(stack_name, port, key_path, runtime):
    # Bring down the compose project previously generated for `runtime` on the
    # stack's docker host. Documented with comments on purpose, since click
    # would show a docstring as --help text.
    target_stack = get_stack(stack_name)
    tunnel = docker_cli(target_stack.docker_ip, port, key_path=key_path)
    with tunnel as docker_client:
        stop_compose(docker_client, get_compose_file(runtime))

        print('Environments stopped.')

# Run the click command group only when this file is executed as a script.
if __name__ == '__main__':
    cli()
